package top.cyuw.simplerpc.remoting.client;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.IdleStateHandler;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import top.cyuw.simplerpc.constant.RpcConstants;
import top.cyuw.simplerpc.context.RpcContext;
import top.cyuw.simplerpc.exception.RpcException;
import top.cyuw.simplerpc.extension.ExtensionLoader;
import top.cyuw.simplerpc.factory.SingletonFactory;
import top.cyuw.simplerpc.registry.ServiceDiscovery;
import top.cyuw.simplerpc.registry.ServiceDiscoverySelector;
import top.cyuw.simplerpc.remoting.RpcRemoting;
import top.cyuw.simplerpc.remoting.codec.RpcMessageDecoder;
import top.cyuw.simplerpc.remoting.codec.RpcMessageEncoder;
import top.cyuw.simplerpc.dto.RpcMessage;
import top.cyuw.simplerpc.dto.RpcRequest;
import top.cyuw.simplerpc.dto.RpcResponse;
import top.cyuw.simplerpc.util.NamedThreadFactory;

import java.net.InetSocketAddress;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * @author chen
 * @date 2023/3/13 11:58
 */
@Slf4j
public class SimpleRpcClient implements RpcRemoting {

    private final RpcContext context;
    private final AtomicBoolean started = new AtomicBoolean(false);
    @Getter
    private final ServiceDiscovery serviceDiscovery;
    private final ChannelManager channelManager = SingletonFactory.getInstance(ChannelManager.class);
    private final FutureManager futureManager = SingletonFactory.getInstance(FutureManager.class);
    private EventLoopGroup eventLoopGroup;
    private Bootstrap bootstrap;

    public SimpleRpcClient() {
        context = RpcContext.getInstance();
        String serviceDiscoveryName = ServiceDiscoverySelector.getServiceDiscovery();
        serviceDiscovery = ExtensionLoader.of(ServiceDiscovery.class).getExtension(serviceDiscoveryName);
    }

    public CompletableFuture<RpcResponse> request(RpcRequest rpcRequest) {
        rpcRequest.setRequestId(UUID.randomUUID().toString());
        CompletableFuture<RpcResponse> future = new CompletableFuture<>();

        InetSocketAddress address = serviceDiscovery.lookupService(rpcRequest);
        if (address == null) {
            throw new RpcException("could not found available service: " + rpcRequest.getFullServiceName());
        }

        Channel channel = getChannel(address);
        if (!channel.isActive()) {
            throw new RpcException("channel is not active when request.");
        }
        futureManager.set(rpcRequest.getRequestId(), future);
        RpcMessage rpcMessage = RpcMessage.builder()
                .messageType(RpcConstants.MESSAGE_TYPE_RPC_REQUEST)
                .body(rpcRequest).build();
        channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) f -> {
            if (!f.isSuccess()) {
                futureManager.remove(rpcRequest.getRequestId());
                f.channel().close();
                future.completeExceptionally(f.cause());
            }
        });
        return future;
    }

    public Channel getChannel(InetSocketAddress address) {
        Channel channel = channelManager.get(address);
        if (channel == null) {
            try {
                channel = doConnect(address);
                channelManager.set(address, channel);
            } catch (Exception e) {
                log.error("connect to " + address + " failed: " + e.getMessage(), e);
            }
        }
        return channel;
    }

    private Channel doConnect(InetSocketAddress address) throws ExecutionException, InterruptedException {
        CompletableFuture<Channel> future = new CompletableFuture<>();
        bootstrap.connect(address).addListener((ChannelFutureListener) f -> {
            if (f.isSuccess()) {
                future.complete(f.channel());
            } else {
                future.completeExceptionally(f.cause());
            }
        });
        return future.get();
    }

    @Override
    public void start() {
        if (started.compareAndSet(false, true)) {
            doStart();
        }
    }

    private void doStart() {
        eventLoopGroup = new NioEventLoopGroup(new NamedThreadFactory("simplerpc-client-nio"));
        bootstrap = new Bootstrap();
        bootstrap.group(eventLoopGroup)
                .channel(NioSocketChannel.class)
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
                .handler(new ChannelInitializer<Channel>() {
                    @Override
                    protected void initChannel(Channel ch) {
                        ch.pipeline().addLast(new IdleStateHandler(0, 5, 0, TimeUnit.SECONDS))
                                .addLast(new RpcMessageEncoder())
                                .addLast(new RpcMessageDecoder())
                                .addLast(new SimpleRpcClientHandler(futureManager));
                    }
                });
    }

    @Override
    public void stop() {
        if (started.compareAndSet(true, false)) {
            doStop();
        }
    }

    private void doStop() {
        if (null != eventLoopGroup) {
            eventLoopGroup.shutdownGracefully().awaitUninterruptibly();
        }
    }

    @Override
    public RpcContext getContext() {
        return context;
    }
}
